# Spectral GCN + Attention Recovery + LSTM
# This code trains and tests the GNN model for COVID-19 infection prediction in Tokyo
# Author: Jiawei Xue, August 26, 2021
# Step 1: read and pack the training and testing data
# Step 2: training epoch, training process, testing
# Step 3: build the model = spectral GCN + Attention Recovery + LSTM
# Step 4: main function
# Step 5: evaluation
# Step 6: visualization
import os
import csv
import json
import copy
import time
import random
import string
import argparse
import numpy as np
import pandas as pd
import geopandas as gpd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
from spectral_T3_GCN_memory_light import SpecGCN
from spectral_T3_GCN_memory_light import SpecGCN_LSTM
#torch.set_printoptions(precision=8)
#hyperparameters for the data setting
X_day, Y_day = 21,7
START_DATE, END_DATE = '20200414','20210207'
#START_DATE, END_DATE = '20200808','20210603'
#START_DATE, END_DATE = '20200720','20210515'
WINDOW_SIZE = 7
#hyperparameters for learning
DROPOUT, ALPHA = 0.50, 0.20
NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE = 100, 8, 0.0001
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2 = 6,4,2
infection_normalize_ratio = 100.0
web_search_normalize_ratio = 100.0
train_ratio = 0.7
validate_ratio = 0.1
#1.total period (mobility+text):
#from 20200201 to 20210620: (29+31+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20)\
#= 335 + 171 = 506;
#2.number of zones: 23;
#3.infection period:
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#1. Mobility: functions 1.2 to 1.7
#2. Text: functions 1.8 to 1.14
#3. Infection: function 1.15
#4. Preprocess: functions 1.16 to 1.24
#5. Learn: functions 1.25 to 1.26
#function 1.1
#get the central area of Tokyo (i.e., the 23 special wards)
#return: a 23-zone shapefile
def read_tokyo_23():
folder = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/present_model_version10/tokyo_23"
file = "tokyo_23zones.shp"
path = os.path.join(folder,file)
data = gpd.read_file(path)
return data
##################1.Mobility#####################
#function 1.2
#get the average of two days' mobility (infection) records
def mob_inf_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = (record1[i]+record2[i])/2.0
return new_record
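#a worked example (hedged, hypothetical data): zones present on only one of
#the two days are dropped by the intersection, e.g.
#mob_inf_average({"d1":{"13101":2.0,"13102":4.0},"d2":{"13101":4.0}},"d1","d2")
#returns {"13101": 3.0}.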
#function 1.3
#get the average of multiple days' mobility (infection) records
def mob_inf_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record:
if zone_id not in list(new_record.keys()):
new_record[zone_id] = record[zone_id]
else:
new_record[zone_id] += record[zone_id]
for new_record_key in new_record:
new_record[new_record_key] = new_record[new_record_key]*1.0/num_day
return new_record
#function 1.4
#generate the dateList: [20200101, 20200102, ..., 20211231]
def generate_dateList():
yearList = ["2020","2021"]
monthList = ["0"+str(i+1) for i in range(9)] + ["10","11","12"]
dayList = ["0"+str(i+1) for i in range(9)] + [str(i) for i in range(10,32)]
day_2020_num = [31,29,31,30,31,30,31,31,30,31,30,31]
day_2021_num = [31,28,31,30,31,30,31,31,30,31,30,31]
date_2020, date_2021 = list(), list()
for i in range(12):
for j in range(day_2020_num[i]):
date_2020.append(yearList[0] + monthList[i] + dayList[j])
for j in range(day_2021_num[i]):
date_2021.append(yearList[1] + monthList[i] + dayList[j])
date_2020_2021 = date_2020 + date_2021
return date_2020_2021
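#a quick sanity check (hedged): 2020 is a leap year, so the list covers
#366 + 365 = 731 days, e.g.
#assert len(generate_dateList()) == 731
#assert generate_dateList()[0] == "20200101"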
#function 1.5
#smooth the mobility (infection) data using the neighborhood average
#under a given window size
#dateList: [20200101, 20200102, ..., 20211231]
def mob_inf_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = mob_inf_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
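#example (hedged): with window_size = 7, the record for "20200410" becomes the
#average over "20200407".."20200413" (3 days on each side); near the ends of
#dateList the window is truncated, and days missing from the data are skipped
#by the set intersection above.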
#function 1.6
#build a record for the zones shared by two days, with all values set to zero
def mob_inf_average_null(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = 0
return new_record
#function 1.7
#read the mobility data from "mobility_feature_20200201.json"...
#return: all_mobility:{"20200201":{('13101','13102'):12345,...},...}
#20200201 to 20210620: 506 days
def read_mobility_data(jcode23):
all_mobility = dict()
mobilityFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/mobility_20210804"
mobilityNameList = os.listdir(mobilityFilePath)
for i in range(len(mobilityNameList)):
day_mobility = dict()
file_name = mobilityNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0] #get the day
file_path = mobilityFilePath + '/' + file_name
f = open(file_path,)
df_file = json.load(f) #read the mobility file
f.close()
for key in df_file:
origin, dest = key.split("_")[0], key.split("_")[1]
if origin in jcode23 and dest in jcode23:
if origin == dest:
day_mobility[(origin, dest)] = 0.0 #ignore the inner-zone flow
else:
day_mobility[(origin, dest)] = df_file[key]
all_mobility[day] = day_mobility
#missing data
all_mobility["20201128"] = mob_inf_average(all_mobility,"20201127","20201129")
all_mobility["20210104"] = mob_inf_average(all_mobility, "20210103","20210105")
return all_mobility
##################2.Text#####################
#function 1.8
#get the average of two days' text records
def text_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
zone_record1, zone_record2 = record1[i], record2[i]
new_zone_record = dict()
for j in zone_record1:
if j in zone_record2:
new_zone_record[j] = (zone_record1[j] + zone_record2[j])/2.0
new_record[i] = new_zone_record
return new_record
#function 1.9
#get the average of multiple days' text records
def text_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record: #zone_id
if zone_id not in new_record:
new_record[zone_id] = dict()
for j in record[zone_id]: #symptom
if j not in new_record[zone_id]:
new_record[zone_id][j] = record[zone_id][j]
else:
new_record[zone_id][j] += record[zone_id][j]
for zone_id in new_record:
for j in new_record[zone_id]:
new_record[zone_id][j] = new_record[zone_id][j]*1.0/num_day
return new_record
#function 1.10
#smooth the text data using the neighborhood average
#under a given window size
def text_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = text_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
#function 1.11
#read the number of user points
def read_point_json():
with open('user_point/mobility_user_point.json') as point1:
user_point1 = json.load(point1)
with open('user_point/mobility_user_point_20210812.json') as point2:
user_point2 = json.load(point2)
user_point_all = dict()
for i in user_point1:
user_point_all[i] = user_point1[i]
for i in user_point2:
user_point_all[i] = user_point2[i]
user_point_all["20201128"] = user_point_all["20201127"] #data missing
user_point_all["20210104"] = user_point_all["20210103"] #data missing
return user_point_all
#function 1.12
#normalize the text search by the number of user points.
def normalize_text_user(all_text, user_point_all):
for day in all_text:
if day in user_point_all:
num_user = user_point_all[day]["num_user"]
all_text_day_new = dict()
all_text_day = all_text[day]
for zone in all_text_day:
if zone not in all_text_day_new:
all_text_day_new[zone] = dict()
for sym in all_text_day[zone]:
all_text_day_new[zone][sym] = all_text_day[zone][sym]*1.0/num_user
all_text[day] = all_text_day_new
return all_text
#function 1.13
#read the text data
#20200201 to 20210620: 506 days
#all_text = {"20200211":{"123":{"code":3,"fever":2,...},...},...}
def read_text_data(jcode23):
all_text = dict()
textFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/text_20210804"
textNameList = os.listdir(textFilePath)
for i in range(len(textNameList)):
day_text = dict()
file_name = textNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0]
file_path = textFilePath + "/" + file_name
f = open(file_path,)
df_file = json.load(f) #read the text file
f.close()
new_dict = dict()
for key in df_file:
if key in jcode23:
new_dict[key] = {key1:df_file[key][key1]*1.0*web_search_normalize_ratio for key1 in df_file[key]}
#new_dict[key] = df_file[key]*WEB_SEARCH_RATIO
all_text[day] = new_dict
all_text["20201030"] = text_average(all_text, "20201029", "20201031") #data missing
return all_text
#function 1.14
#perform the min-max normalization for the text data.
def min_max_text_data(all_text,jcode23):
#calculate the min_max
#region_key: sym: [min,max]
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい', '吐き気', '嘔吐', '筋肉痛', '動悸', \
'副鼻腔炎', '発疹', 'くしゃみ', '倦怠感', '寒気', '脱水', \
'中咽頭', '関節痛', '不眠症', '睡眠障害', '鼻漏', '片頭痛', \
'多汗症', 'ほてり', '胸痛', '発汗', '無気力', '呼吸困難', \
'喘鳴', '目の痛み', '体の痛み', '無嗅覚症', '耳の痛み', \
'錯乱', '見当識障害', '胸の圧迫感', '鼻の乾燥', '耳感染症', \
'味覚消失', '上気道感染症', '眼感染症', '食欲減少'])
region_sym_min_max = dict()
for key in jcode23: #initialize
region_sym_min_max[key] = dict()
for sym in text_list:
region_sym_min_max[key][sym] = [1000000,0] #min, max
for day in all_text: #update
for key in jcode23:
for sym in text_list:
if sym in all_text[day][key]:
count = all_text[day][key][sym]
if count < region_sym_min_max[key][sym][0]:
region_sym_min_max[key][sym][0] = count
if count > region_sym_min_max[key][sym][1]:
region_sym_min_max[key][sym][1] = count
#print ("region_sym_min_max",region_sym_min_max)
for key in jcode23: #normalize
for sym in text_list:
min_count,max_count=region_sym_min_max[key][sym][0],region_sym_min_max[key][sym][1]
for day in all_text:
if sym in all_text[day][key]:
if max_count-min_count == 0:
all_text[day][key][sym] = 1
else:
all_text[day][key][sym] = (all_text[day][key][sym]-min_count)*1.0/(max_count-min_count)
#print("all_text[day][key][sym]",all_text[day][key][sym])
return all_text
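#worked example (hedged): if a symptom's smoothed counts in one zone span
#[min, max] = [2, 10], a raw value v is rescaled to (v - 2)/(10 - 2);
#a constant series (max == min) is mapped to 1 to avoid division by zero.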
##################3.Infection#####################
#function 1.15
#read the infection data
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#all_infection = {"20200201":{"123":1,"123":2}}
def read_infection_data(jcode23):
all_infection = dict()
infection_path = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/patient_20210725.json"
f = open(infection_path,)
df_file = json.load(f) #read the infection file
f.close()
for zone_id in df_file:
for one_day in df_file[zone_id]:
daySplit = one_day.split("/")
year, month, day = daySplit[0], daySplit[1], daySplit[2]
if len(month) == 1:
month = "0" + month
if len(day) == 1:
day = "0" + day
new_date = year + month + day
if str(zone_id[0:5]) in jcode23:
if new_date not in all_infection:
all_infection[new_date] = {zone_id[0:5]:df_file[zone_id][one_day]*1.0/infection_normalize_ratio}
else:
all_infection[new_date][zone_id[0:5]] = df_file[zone_id][one_day]*1.0/infection_normalize_ratio
#missing
date_list = [str(20200316+i) for i in range(15)]
for date in date_list:
all_infection[date] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200514'] = mob_inf_average(all_infection,'20200513','20200515')
all_infection['20200519'] = mob_inf_average(all_infection,'20200518','20200520')
all_infection['20200523'] = mob_inf_average(all_infection,'20200522','20200524')
all_infection['20200530'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20200531'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20201231'] = mob_inf_average(all_infection,'20201230','20210101')
all_infection['20210611'] = mob_inf_average(all_infection,'20210610','20210612')
#outlier
all_infection['20200331'] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200910'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200911'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200511'] = mob_inf_average(all_infection,'20200510','20200512')
all_infection['20201208'] = mob_inf_average(all_infection,'20201207','20201209')
all_infection['20210208'] = mob_inf_average(all_infection,'20210207','20210209')
all_infection['20210214'] = mob_inf_average(all_infection,'20210213','20210215')
#compute the daily new cases by differencing the cumulative counts
all_infection_subtraction = dict()
all_infection_subtraction['20200331'] = all_infection['20200331']
all_keys = list(all_infection.keys())
all_keys.sort()
for i in range(len(all_keys)-1):
record = dict()
for j in all_infection[all_keys[i+1]]:
record[j] = all_infection[all_keys[i+1]][j] - all_infection[all_keys[i]][j]
all_infection_subtraction[all_keys[i+1]] = record
return all_infection_subtraction, all_infection
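#worked example (hedged, normalized units): if the cumulative counts for zone
#"13101" are {..., "20200501": 1.20, "20200502": 1.27}, then
#all_infection_subtraction["20200502"]["13101"] ≈ 0.07, i.e. the daily new
#cases divided by infection_normalize_ratio.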
##################4.Preprocess#####################
#function 1.16
#ensemble the mobility, text, and infection.
#all_mobility = {"20200201":{('13101','13102'):12345,...},...}
#all_text = {"20200201":{"13101":{"cold":3,"fever":2,...},...},...}
#all_infection = {"20200316":{"13101":1,"13102":2},...}
#all_x_y = {"0":[[mobility_1,text_1,...,mobility_x_day,text_x_day], [infection_1,...,infection_y_day],\
#[infection_1,...,infection_x_day], day_order],...}
#x_days, y_days: use x_days to predict y_days
def ensemble(all_mobility, all_text, all_infection, x_days, y_days, all_day_list):
all_x_y = dict()
for j in range(len(all_day_list) - x_days - y_days + 1):
x_sample, y_sample, x_sample_infection = list(), list(), list()
#add the data from all_day_list[0+j] to all_day_list[x_days-1+j]
for k in range(x_days):
day = all_day_list[k + j]
x_sample.append(all_mobility[day])
x_sample.append(all_text[day])
x_sample_infection.append(all_infection[day]) #concatenate with the infection data
#add the data from all_day_list[x_days+j] to all_day_list[x_days+y_day-1+j]
for k in range(y_days):
day = all_day_list[x_days + k + j]
y_sample.append(all_infection[day])
all_x_y[str(j)] = [x_sample, y_sample, x_sample_infection,j]
return all_x_y
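#worked example (hedged): with x_days = 21, y_days = 7 and a 300-day
#all_day_list, this yields 300 - 21 - 7 + 1 = 273 samples; sample j packs
#mobility/text/infection for days j..j+20 as inputs and the infection records
#for days j+21..j+27 as targets.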
#function 1.17
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_key = [all_x_y[str(i)] for i in range(n_train)]
validate_key = [all_x_y[str(i+n_train)] for i in range(n_validate)]
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key
#function 1.18
#the second data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_2(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train+n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if i % 9 == 8:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
#function 1.19
#the third data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_3(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n - n_train - n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train + n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if (n_train + n_validate-i) % 2 == 0 and (n_train + n_validate-i) <= 2*n_validate:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
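#example (hedged): with n_train + n_validate = 20 and n_validate = 4, the
#condition selects i in {12, 14, 16, 18} for validation (every second index
#within the last 2*n_validate positions); all other indices stay in training.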
#function 1.20
#find the mobility dates starting from the day that is x_days before start_date
#start_date = "20200331", x_days = 7
def sort_date(all_mobility, start_date, x_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
mobility_date_cut = mobility_date_list[idx-x_days:]
return mobility_date_cut
#function 1.21
#find the mobility dates starting from the day that is x_days before start_date
#and ending at the day that is y_days after end_date
#start_date = "20200331", x_days = 7
def sort_date_2(all_mobility, start_date, x_days, end_date, y_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
idx2 = mobility_date_list.index(end_date)
mobility_date_cut = mobility_date_list[idx-x_days:idx2+y_days]
return mobility_date_cut
#function 1.22
#get the mappings from zone id to index and from symptom text to index.
#get zone_text_to_idx
def get_zone_text_to_idx(all_infection):
zone_list = list(set(all_infection["20200401"].keys()))
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい'])
zone_list.sort()
zone_dict = {str(zone_list[i]):i for i in range(len(zone_list))}
text_dict = {str(text_list[i]):i for i in range(len(text_list))}
return zone_dict, text_dict
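#example (hedged): zone_dict maps the 23 sorted ward codes to indices,
#{"13101":0, "13102":1, ..., "13123":22}, and text_dict maps the 8 symptom
#keywords above to 0..7; these index the matrix rows/columns in to_matrix.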
#function 1.23
#change the data format to matrix
#zoneid_to_idx = {"13101":0, "13102":1, ..., "13123":22}
#sym_to_idx = {"cough":0,...}
#mobility: {('13101', '13102'): 709973, ...}
#text: {'13101': {'痛み': 51,...},...}
#infection: {'13101': 50, '13102': 137, '13103': 401,...}
#data_type = {"mobility", "text", "infection"}
def to_matrix(zoneid_to_idx, sym_to_idx, input_data, data_type):
n_zone, n_text = len(zoneid_to_idx), len(sym_to_idx)
if data_type == "mobility":
result = np.zeros((n_zone, n_zone))
for key in input_data:
from_id, to_id = key[0], key[1]
from_idx, to_idx = zoneid_to_idx[from_id], zoneid_to_idx[to_id]
result[from_idx][to_idx] += input_data[key]
if data_type == "text":
result = np.zeros((n_zone, n_text))
for key1 in input_data:
for key2 in input_data[key1]:
if key1 in list(zoneid_to_idx.keys()) and key2 in list(sym_to_idx.keys()):
zone_idx, text_idx = zoneid_to_idx[key1], sym_to_idx[key2]
result[zone_idx][text_idx] += input_data[key1][key2]
if data_type == "infection":
result = np.zeros(n_zone)
for key in input_data:
zone_idx = zoneid_to_idx[key]
result[zone_idx] += input_data[key]
return result
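#usage sketch (hedged, hypothetical input):
#to_matrix(zoneid_to_idx, sym_to_idx, {("13101","13102"): 500.0}, "mobility")
#returns a 23x23 array with 500.0 at [0][1] and zeros elsewhere; "text" gives
#a (23, 8) array and "infection" a length-23 vector.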
#function 1.24
#change the data to the matrix format
def change_to_matrix(data, zoneid_to_idx, sym_to_idx):
data_result = list()
for i in range(len(data)):
combine1, combine2 = list(), list()
combine3 = list() #NEW
mobility_text = data[i][0]
x_infection_all = data[i][2] #the x_days infection data
day_order = data[i][3] #NEW the order of the day
for j in range(round(len(mobility_text)*1.0/2)):
mobility, text = mobility_text[2*j], mobility_text[2*j+1]
x_infection = x_infection_all[j] #NEW
new_mobility = to_matrix(zoneid_to_idx, sym_to_idx, mobility, "mobility")
new_text = to_matrix(zoneid_to_idx, sym_to_idx, text, "text")
combine1.append(new_mobility)
combine1.append(new_text)
new_x_infection = to_matrix(zoneid_to_idx, sym_to_idx, x_infection, "infection") #NEW
combine3.append(new_x_infection) #NEW
for j in range(len(data[i][1])):
infection = data[i][1][j]
new_infection = to_matrix(zoneid_to_idx, sym_to_idx, infection, "infection")
combine2.append(new_infection)
data_result.append([combine1,combine2,combine3,day_order]) #mobility/text; infection_y; infection_x; day_order
return data_result
##################5.Learn#####################
#function 1.25
def visual_loss(e_losses, vali_loss, test_loss):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1,y2,y3 = copy.copy(e_losses), copy.copy(vali_loss), copy.copy(test_loss)
plt.plot(x,y1,linewidth=1, label="train")
plt.plot(x,y2,linewidth=1, label="validate")
plt.plot(x,y3,linewidth=1, label="test")
plt.legend()
plt.title('Loss decline on entire training/validation/testing data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 1.26
def visual_loss_train(e_losses):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1 = copy.copy(e_losses)
plt.plot(x,y1,linewidth=1, label="train")
plt.legend()
plt.title('Loss decline on entire training data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 2.1
#normalize the input mobility matrix so that each column sums to one
def normalize_column_one(input_matrix):
column_sum = np.sum(input_matrix, axis=0)
row_num, column_num = len(input_matrix), len(input_matrix[0])
for i in range(row_num):
for j in range(column_num):
input_matrix[i][j] = input_matrix[i][j]*1.0/column_sum[j]
return input_matrix
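#a vectorized alternative (a minimal sketch, not used below; it assumes
#input_matrix is a 2-D numpy array whose column sums are nonzero, as holds for
#the mobility matrices produced by to_matrix):
def normalize_column_one_vec(input_matrix):
    #divide every column by its sum so that each column sums to one
    return input_matrix / np.sum(input_matrix, axis=0, keepdims=True)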
#function 2.2
#evaluate the trained model on validation or testing data.
def validate_test_process(trained_model, vali_test_data):
criterion = nn.MSELoss()
vali_test_y = [vali_test_data[i][1] for i in range(len(vali_test_data))]
y_real = torch.tensor(vali_test_y)
vali_test_x = [vali_test_data[i] for i in range(len(vali_test_data))]
vali_test_x = convertAdj(vali_test_x)
y_hat = trained_model.run_specGCN_lstm(vali_test_x)
loss = criterion(y_hat.float(), y_real.float()) ###Calculate the loss
return loss, y_hat, y_real
#function 2.3
#convert the mobility matrices in x_batch as follows:
#normalize the inter-zone flows so that the in-flow of each zone sums to 1.
def convertAdj(x_batch):
#x_batch:(n_batch, 0/1, 2*i+1)
x_batch_new = copy.copy(x_batch)
n_batch = len(x_batch)
days = round(len(x_batch[0][0])/2)
for i in range(n_batch):
for j in range(days):
mobility_matrix = x_batch[i][0][2*j]
x_batch_new[i][0][2*j] = normalize_column_one(mobility_matrix) #20210818
return x_batch_new
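#post-condition sketch (hedged): after convertAdj, every mobility matrix in the
#batch is column-stochastic, i.e.
#np.allclose(np.sum(x_batch_new[i][0][2*j], axis=0), 1.0)
#holds for all i, j, provided every zone has nonzero in-flow.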
#function 2.4
#a training epoch
def train_epoch_option(model, opt, criterion, trainX_c, trainY_c, batch_size):
model.train()
losses = []
batch_num = 0
for beg_i in range(0, len(trainX_c), batch_size):
batch_num += 1
if batch_num % 16 ==0:
print ("batch_num: ", batch_num, "total batch number: ", int(len(trainX_c)/batch_size))
x_batch = trainX_c[beg_i:beg_i+batch_size]
y_batch = torch.tensor(trainY_c[beg_i:beg_i+batch_size])
opt.zero_grad()
x_batch = convertAdj(x_batch) #conduct the column normalization
y_hat = model.run_specGCN_lstm(x_batch) ###Attention
loss = criterion(y_hat.float(), y_batch.float()) #MSE loss
#opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.data.numpy())
return sum(losses)/float(len(losses)), model
#function 2.5
#run multiple training epochs
def train_process(train_data, lr, num_epochs, net, criterion, bs, vali_data, test_data):
opt = optim.Adam(net.parameters(), lr, betas = (0.9,0.999), weight_decay=0)
train_y = [train_data[i][1] for i in range(len(train_data))]
e_losses = list()
e_losses_vali = list()
e_losses_test = list()
time00 = time.time()
for e in range(num_epochs):
time1 = time.time()
print ("current epoch: ",e, "total epoch: ", num_epochs)
number_list = list(range(len(train_data)))
random.shuffle(number_list)
trainX_sample = [train_data[number_list[j]] for j in range(len(number_list))]
trainY_sample = [train_y[number_list[j]] for j in range(len(number_list))]
loss, net = train_epoch_option(net, opt, criterion, trainX_sample, trainY_sample, bs)
print ("train loss", loss*infection_normalize_ratio*infection_normalize_ratio)
e_losses.append(loss*infection_normalize_ratio*infection_normalize_ratio)
loss_vali, y_hat_vali, y_real_vali = validate_test_process(net, vali_data)
loss_test, y_hat_test, y_real_test = validate_test_process(net, test_data)
e_losses_vali.append(float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
e_losses_test.append(float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
print ("validate loss", float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
print ("test loss", float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
if e>=2 and (e+1)%10 ==0:
visual_loss(e_losses, e_losses_vali, e_losses_test)
visual_loss_train(e_losses)
time2 = time.time()
print ("running time for this epoch:", time2 - time1)
time01 = time.time()
print ("---------------------------------------------------------------")
print ("---------------------------------------------------------------")
#print ("total running time until now:", time01 - time00)
#print ("------------------------------------------------")
#print("specGCN_weight", net.specGCN.layer1.W)
#print("specGCN_weight_grad", net.specGCN.layer1.W.grad)
#print ("------------------------------------------------")
#print("memory decay matrix", net.v)
#print("memory decay matrix grad", net.v.grad)
#print ("------------------------------------------------")
#print ("lstm weight", net.lstm.all_weights[0][0])
#print ("lstm weight grad", net.lstm.all_weights[0][0].grad)
#print ("------------------------------------------------")
#print ("fc1.weight", net.fc1.weight)
#print ("fc1 weight grd", net.fc1.weight.grad)
#print ("---------------------------------------------------------------")
#print ("---------------------------------------------------------------")
return e_losses, net
#function 3.1
def read_data():
jcode23 = list(read_tokyo_23()["JCODE"]) #1.1 get the tokyo 23 zone shapefile
all_mobility = read_mobility_data(jcode23) #1.7 read the mobility data
all_text = read_text_data(jcode23) #1.13 read the text data
all_infection, all_infection_cum = read_infection_data(jcode23) #1.15 read the infection data
#smooth the data using the 7-day moving average
window_size = WINDOW_SIZE #20210818
dateList = generate_dateList() #20210818
all_mobility = mob_inf_smooth(all_mobility, window_size, dateList) #20210818
all_infection = mob_inf_smooth(all_infection, window_size, dateList) #20210818
#smooth, user, min-max.
point_json = read_point_json() #20210821
all_text = normalize_text_user(all_text, point_json) #20210821
all_text = text_smooth(all_text, window_size, dateList) #20210818
all_text = min_max_text_data(all_text,jcode23) #20210820
x_days, y_days = X_day, Y_day
mobility_date_cut = sort_date_2(all_mobility, START_DATE, x_days, END_DATE, y_days)
all_x_y = ensemble(all_mobility, all_text, all_infection, x_days, y_days, mobility_date_cut)
train_original, validate_original, test_original, train_list, validation_list =\
split_data_3(all_x_y,train_ratio,validate_ratio)
zone_dict, text_dict = get_zone_text_to_idx(all_infection) #get zone_idx, text_idx
train_x_y = change_to_matrix(train_original, zone_dict, text_dict) #get train
print ("train_x_y_shape",len(train_x_y),"train_x_y_shape[0]",len(train_x_y[0]))
validate_x_y = change_to_matrix(validate_original, zone_dict, text_dict) #get validate
test_x_y = change_to_matrix(test_original, zone_dict, text_dict) #get test
print (len(train_x_y)) #210
print (len(train_x_y[0][0])) #42 = 2*X_day matrices (mobility + text)
print (np.shape(train_x_y[0][0][0])) #(23,23)
print (np.shape(train_x_y[0][0][1])) #(23,8)
#print ("---------------------------------finish data reading and preprocessing------------------------------------")
return train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, train_original, validate_original, test_original, train_list, validation_list
#function 3.2
#train the model
def model_train(train_x_y, vali_data, test_data):
#3.2.1 define the model
input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2 = len(train_x_y[0][0][1][1]),\
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2
dropout_1, alpha_1, N = DROPOUT, ALPHA, len(train_x_y[0][0][1])
G_L_Model = SpecGCN_LSTM(X_day, Y_day, input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2, dropout_1,N) ###Attention
#3.2.2 train the model
num_epochs, batch_size, learning_rate = NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE #model train
criterion = nn.MSELoss()
e_losses, trained_model = train_process(train_x_y, learning_rate, num_epochs, G_L_Model, criterion, batch_size,\
vali_data, test_data)
return e_losses, trained_model
#4.1
#read the data
train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, \
train_original, validate_original, test_original, train_list, validation_list =\
read_data()
#train_x_y, validate_x_y, test_x_y = normalize(train_x_y, validate_x_y, test_x_y)
#train_x_y = train_x_y[0:30]
print (len(train_x_y))
print ("---------------------------------finish data preparation------------------------------------")
#observed output: 210 training samples, each with 4 components;
#42 input matrices per sample with shapes (23, 23) (mobility) and (23, 8) (text)
#4.2
#train the model
e_losses, trained_model = model_train(train_x_y, validate_x_y, test_x_y)
print ("---------------------------finish model training-------------------------")
#observed training log (100 epochs, 26 batches per epoch, ~7 s per epoch):
#train loss fell from 47.65 to 9.51, validation loss from 36.77 to 10.86,
#and test loss from 779.87 to 245.93 (all scaled by infection_normalize_ratio^2).
running time for this epoch: 7.382112979888916 --------------------------------------------------------------- --------------------------------------------------------------- ---------------------------finish model training-------------------------
#4.3 check the sizes of the training, validation, and test sets
print (len(train_x_y))
print (len(validate_x_y))
print (len(test_x_y))
#4.3.1 model validation
validation_result, validate_hat, validate_real = validate_test_process(trained_model, validate_x_y)
print ("---------------------------------finish model validation------------------------------------")
print (len(validate_hat))
print (len(validate_real))
#4.3.2 model testing
test_result, test_hat, test_real = validate_test_process(trained_model, test_x_y)
print ("---------------------------------finish model testing------------------------------------")
print (len(test_real))
print (len(test_hat))
210
30
60
---------------------------------finish model validation------------------------------------
30
30
---------------------------------finish model testing------------------------------------
60
60
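#sanity check: 210 + 30 + 60 = 300 samples in total, consistent with the
#split ratios train_ratio = 0.7 (210/300) and validate_ratio = 0.1 (30/300),
#leaving the remaining 0.2 of the data (60/300) for testing.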
#5.1 RMSE, MAPE, MAE, RMSLE
def RMSELoss(yhat, y):
    #root mean squared error
    return float(torch.sqrt(torch.mean((yhat-y)**2)))
def MAPELoss(yhat, y):
    #mean absolute percentage error (assumes y has no zero entries)
    return float(torch.mean(torch.div(torch.abs(yhat-y), y)))
def MAELoss(yhat, y):
    #mean absolute error
    return float(torch.mean(torch.abs(yhat-y)))
def RMSLELoss(yhat, y):
    #root mean squared logarithmic error
    log_yhat = torch.log(yhat+1)
    log_y = torch.log(y+1)
    return float(torch.sqrt(torch.mean((log_yhat-log_y)**2)))
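#a minimal sanity check of the four metrics on toy tensors (the values and
#the names _yhat/_y are illustrative only, not model output):
_yhat = torch.tensor([1.0, 2.0, 3.0])
_y = torch.tensor([1.5, 2.0, 2.5])
print (RMSELoss(_yhat, _y))   #sqrt(mean([0.25, 0.00, 0.25])) ~ 0.408
print (MAPELoss(_yhat, _y))   #mean([0.333, 0.000, 0.200]) ~ 0.178
print (MAELoss(_yhat, _y))    #mean([0.5, 0.0, 0.5]) ~ 0.333
print (RMSLELoss(_yhat, _y))  #RMSE of log(1+x), less sensitive to large counts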
#compute RMSE
rmse_validate = list()
rmse_test = list()
for i in range(len(validate_x_y)):
    rmse_validate.append(RMSELoss(validate_hat[i], validate_real[i]))
for i in range(len(test_x_y)):
    rmse_test.append(RMSELoss(test_hat[i], test_real[i]))
print ("rmse_validate mean", np.mean(rmse_validate))
print ("rmse_test mean", np.mean(rmse_test))
#compute MAE
mae_validate = list()
mae_test = list()
for i in range(len(validate_x_y)):
mae_validate.append(float(MAELoss(validate_hat[i],validate_real[i])))
for i in range(len(test_x_y)):
mae_test.append(float(MAELoss(test_hat[i],test_real[i])))
print ("mae_validate mean", np.mean(mae_validate))
print ("mae_test mean", np.mean(mae_test))
#rescale RMSE and MAE from normalized units back to daily case counts
mae_validate, rmse_validate, mae_test, rmse_test =\
    np.array(mae_validate)*infection_normalize_ratio, np.array(rmse_validate)*infection_normalize_ratio,\
    np.array(mae_test)*infection_normalize_ratio, np.array(rmse_test)*infection_normalize_ratio
print ("-----------------------------------------")
print ("mae_validate mean", round(np.mean(mae_validate),3), " rmse_validate mean", round(np.mean(rmse_validate),3))
print ("mae_test mean", round(np.mean(mae_test),3), " rmse_test mean", round(np.mean(rmse_test),3))
print ("-----------------------------------------")
rmse_validate mean 0.030588341888637016
rmse_test mean 0.1343292538425676
mae_validate mean 0.023897380344336094
mae_test mean 0.10747848775348863
-----------------------------------------
mae_validate mean 2.39  rmse_validate mean 3.059
mae_test mean 10.748  rmse_test mean 13.433
-----------------------------------------
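#MAPELoss and RMSLELoss are defined above but not reported; a sketch that
#mirrors the RMSE/MAE loops (note: MAPE is undefined whenever a real value
#is zero, which can happen for zone-level daily cases):
mape_test = [MAPELoss(test_hat[i], test_real[i]) for i in range(len(test_x_y))]
rmsle_test = [RMSLELoss(test_hat[i], test_real[i]) for i in range(len(test_x_y))]
print ("mape_test mean", np.mean(mape_test))
print ("rmsle_test mean", np.mean(rmsle_test))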
#inspect one validation sample: predicted vs. real per-zone values on the
#last forecast day, together with their citywide sums
print(validate_hat[0][Y_day-1])
print(torch.sum(validate_hat[0][Y_day-1]))
print(validate_real[0][Y_day-1])
print(torch.sum(validate_real[0][Y_day-1]))
tensor([0.0167, 0.0461, 0.0663, 0.0839, 0.0381, 0.0427, 0.0518, 0.0574, 0.0939,
0.0629, 0.1370, 0.1398, 0.0572, 0.0740, 0.0516, 0.0396, 0.0362, 0.0339,
0.0529, 0.1124, 0.0864, 0.0528, 0.0809], grad_fn=<SelectBackward>)
tensor(1.5144, grad_fn=<SumBackward0>)
tensor([0.0100, 0.0357, 0.0771, 0.1000, 0.0243, 0.0371, 0.0300, 0.0329, 0.0671,
0.0443, 0.1800, 0.1643, 0.0514, 0.0700, 0.0743, 0.0286, 0.0443, 0.0343,
0.0543, 0.0757, 0.1014, 0.0457, 0.0486], dtype=torch.float64)
tensor(1.4314, dtype=torch.float64)
x = range(len(rmse_validate))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_validate), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_validate), 'go-',linewidth=0.8, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of validation',fontsize=12)
plt.ylabel("RMSE/MAE daily new cases",fontsize=10)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
x = range(len(mae_test))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_test), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_test), 'go-',linewidth=0.5, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of test',fontsize=12)
plt.ylabel("RMSE/MAE Daily new cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
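#the two metric plots above share the same structure; a small helper could
#remove the duplication (plot_metric_series is a hypothetical name):
def plot_metric_series(rmse, mae, xlabel):
    x = range(len(rmse))
    plt.figure(figsize=(8,2), dpi=300)
    plt.plot(x, np.array(rmse), 'ro-', linewidth=0.8, markersize=1.2, label='RMSE')
    plt.plot(x, np.array(mae), 'go-', linewidth=0.8, markersize=1.2, label='MAE')
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel("RMSE/MAE daily new cases", fontsize=10)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend()
    plt.grid()
    plt.show()
#usage: plot_metric_series(rmse_test, mae_test, 'Date from the first day of test')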
from scipy import stats
#validate
y_days = Y_day
validate_hat_sum = [float(torch.sum(validate_hat[i][y_days-1])) for i in range(len(validate_hat))]
validate_real_sum = [float(torch.sum(validate_real[i][y_days-1])) for i in range(len(validate_real))]
print ("the correlation between validation: ", stats.pearsonr(validate_hat_sum, validate_real_sum)[0])
#test
test_hat_sum = [float(torch.sum(test_hat[i][y_days-1])) for i in range(len(test_hat))]
test_real_sum = [float(torch.sum(test_real[i][y_days-1])) for i in range(len(test_real))]
print ("the correlation between test: ", stats.pearsonr(test_hat_sum, test_real_sum)[0])
#train (note: the train correlation uses the first forecast day, index 0,
#whereas validation and test use the last day, index y_days-1)
train_result, train_hat, train_real = validate_test_process(trained_model, train_x_y)
train_hat_sum = [float(torch.sum(train_hat[i][0])) for i in range(len(train_hat))]
train_real_sum = [float(torch.sum(train_real[i][0])) for i in range(len(train_real))]
print ("Pearson correlation (train):", stats.pearsonr(train_hat_sum, train_real_sum)[0])
Pearson correlation (validation): 0.84419518060181
Pearson correlation (test): 0.2713049497993222
Pearson correlation (train): 0.9423266874790944
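#np.corrcoef computes the same Pearson coefficient and serves as a quick
#cross-check of the scipy result (entry [0][1] of the 2x2 correlation matrix):
print (np.corrcoef(validate_hat_sum, validate_real_sum)[0][1])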
y1List = [np.sum(list(train_original[i+1][1][Y_day-1].values())) for i in range(len(train_original)-1)]
y2List = [np.sum(list(validate_original[i][1][Y_day-1].values())) for i in range(len(validate_original))]
y2List_hat = [float(torch.sum(validate_hat[i][Y_day-1])) for i in range(len(validate_hat))]
y3List = [np.sum(list(test_original[i][1][Y_day-1].values())) for i in range(len(test_original))]
y3List_hat = [float(torch.sum(test_hat[i][Y_day-1])) for i in range(len(test_hat))]
#x1 = np.array(range(len(y1List)))
#x2 = np.array([len(y1List)+j for j in range(len(y2List))])
x1 = train_list
x2 = validation_list
x3 = np.array([len(y1List)+len(y2List)+j for j in range(len(y3List))])
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x1[0: len(y1List)], np.array(y1List)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='train')
l2 = plt.plot(x2, np.array(y2List)*infection_normalize_ratio, 'go-',linewidth=0.8, markersize=2.0, label='validate')
l3 = plt.plot(x2, np.array(y2List_hat)*infection_normalize_ratio, 'g-',linewidth=2, markersize=0.1, label='validate_predict')
l4 = plt.plot(x3, np.array(y3List)*infection_normalize_ratio, 'bo-',linewidth=0.8, markersize=2, label='test')
l5 = plt.plot(x3, np.array(y3List_hat)*infection_normalize_ratio, 'b-',linewidth=2, markersize=0.1, label='test_predict')
#plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
plt.ylabel("Daily infection cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
my_x_ticks = list(range(0, 301, 60))  #[0, 60, 120, 180, 240, 300]
plt.xticks(my_x_ticks, fontsize=8)
plt.yticks(my_y_ticks, fontsize=12)
plt.title("SpectralGCN")
plt.legend()
plt.grid()
#plt.savefig('sg_peak4_21_21_1feature_0005.pdf',bbox_inches = 'tight')
plt.show()
def getPredictionPlot(k):
    #plot real vs. predicted daily infections for zone k over the test period
    x_k = [i for i in range(len(test_real))]
    real_k = [float(test_real[i][y_days-1][k]) for i in range(len(test_real))]
    predict_k = [float(test_hat[i][y_days-1][k]) for i in range(len(test_hat))]
    plt.figure(figsize=(4,2.5), dpi=300)
    l1 = plt.plot(x_k, np.array(real_k)*infection_normalize_ratio, 'ro-', linewidth=0.8, markersize=2.0, label='real', alpha=0.8)
    l2 = plt.plot(x_k, np.array(predict_k)*infection_normalize_ratio, 'o-', color='black', linewidth=0.8, markersize=2.0, alpha=0.8, label='predict')
    #plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
    #plt.ylabel("Daily infection cases",fontsize=10)
    my_y_ticks = np.arange(0, 100, 40)
    my_x_ticks = list(range(0, 61, 10))  #[0, 10, ..., 60]
    plt.xticks(my_x_ticks, fontsize=14)
    plt.yticks(my_y_ticks, fontsize=14)
    plt.title("Real and predicted daily infections for region "+str(k))
    plt.legend()
    plt.grid()
    #plt.savefig('sg_peak4_21_21_1feature_0005.pdf',bbox_inches='tight')
    plt.show()
for i in range(23):
    getPredictionPlot(i)
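#a possible follow-up: per-region Pearson correlation on the test set, to see
#which of the 23 zones the model tracks best (region_corr is a hypothetical name):
region_corr = list()
for k in range(23):
    real_k = [float(test_real[i][y_days-1][k]) for i in range(len(test_real))]
    predict_k = [float(test_hat[i][y_days-1][k]) for i in range(len(test_hat))]
    region_corr.append(stats.pearsonr(predict_k, real_k)[0])
print ("best tracked region:", int(np.argmax(region_corr)))
print ("worst tracked region:", int(np.argmin(region_corr)))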